import scalevi.var_dists.var_dists_branched as var_dists_branched
import scalevi.var_dists.var_dists as var_dists
import scalevi.distributions.scale_transforms as scale_transforms
import scalevi.utils.utils as utils
import scalevi.dataloader as dataloader
import jax
import jax.numpy as np

def initial_params_args(config_dict):
    """Return dummy keyword inputs for initialising an amortized branch distribution.

    For ``var_dist`` values outside the amortized-branch family an empty dict
    is returned. Otherwise the result always holds ``"rngs"`` (a fixed
    ``jax.random.PRNGKey(0)``) and ``"θ"``. Depending on
    ``config_dict['encoder']`` it may also hold ``"x"``, and ``"mask"`` for
    the encoder-network family. Every array is filled with ones. Presumably
    only the shapes matter to the caller — TODO confirm.

    Parameters
    ----------
    config_dict : dict
        Configuration. The branch taken decides which of ``'var_dist'``,
        ``'encoder'``, ``'model'``, ``'D_data'``, ``'M_data'``, ``'N_leaves'``
        and the optional ``'D_par'`` (default 1) are read.

    Returns
    -------
    dict
        Mapping from argument name to a ``jax`` array or PRNG key.
    """
    amortized_dists = (
        "AmortizedBranchGaussian",
        "AmortizedBranchGaussianWithSampleEval",
        "AmortizedBranchDiagonalWithSampleEval",
        "AmortizedBranchBlockGaussianWithSampleEval",
    )
    network_encoders = (
        'Encoder',
        "MaskedEncoder",
        "MaskedEncoder_v2",
        "MaskedEncoder_vGaussian",
        "MaskedEncoder_vBlockGaussian",
        "MaskedEncoder_vDiagonal",
    )
    bilinear_encoders = ('Bilinear', 'BilinearMean')

    # Guard clause: only the amortized branch distributions need inputs.
    if config_dict['var_dist'] not in amortized_dists:
        return {}

    init_args = {
        "rngs": jax.random.PRNGKey(0),
        "θ": np.ones(config_dict.get('D_par', 1)),
    }
    encoder = config_dict['encoder']

    if encoder in network_encoders:
        D = config_dict['D_data']
        n_rows = config_dict["M_data"]
        model = config_dict['model']
        # These models are given one extra column in the dummy x.
        has_extra_column = model in ("BranchConditionalGaussian",
                                     "BranchConditionalBernoulli",
                                     "RS")
        # RS uses D + D(D+1)/2 entries for θ; every other model uses D.
        θ_size = D + (D + 1) * D // 2 if model == "RS" else D
        # θ is assigned before x so the dict's key order is unchanged.
        init_args["θ"] = np.ones(θ_size)
        init_args["x"] = np.ones((n_rows, D + 1 if has_extra_column else D))
        # The mask is supplied unconditionally for this encoder family.
        init_args["mask"] = np.ones(config_dict['M_data'])

    elif encoder in bilinear_encoders:
        x = np.ones(config_dict["N_leaves"])
        θ = np.ones(config_dict.get('D_par', 1))
        model = config_dict['model']
        if model in ("BranchGaussianWithData",
                     "BranchConditionalGaussian",
                     "BranchConditionalBernoulli"):
            θ = np.ones(config_dict['D_data'])
        elif model == "RS":
            D = config_dict['D_data']
            θ = np.ones(D + (D + 1) * D // 2)
        init_args["x"] = x
        init_args["θ"] = θ

    return init_args


def get_var_dist_args(config_dict):
    """Assemble the constructor keyword arguments for the configured variational distribution.

    Parameters
    ----------
    config_dict : dict
        Experiment configuration. Always read: ``'N_leaves'``, ``'model'`` and
        ``'var_dist'``. ``'D_data'`` is read for every supported model.
        Optional: ``'scale_transform'`` (default ``"Proximal"``) and
        ``'scale_transform_gamma'`` (default ``1.0``, Proximal only).
        Amortized distributions also read ``'encoder'``, ``'use_mask'`` and
        the ``'encoder_*'`` settings. ``'r'`` is read when ``var_dist`` is
        ``"LowRankGauss"``.

    Returns
    -------
    dict
        Keyword arguments that ``get_var_dist`` unpacks into the variational
        distribution's constructor.

    Raises
    ------
    NotImplementedError
        If ``config_dict['model']`` has no known ``D_par``/``D_kid`` mapping.
    """
    args = {}

    ################################################
    # Scale transform arguments
    ################################################
    # The class is looked up as "<name>ScaleTransform". Only the Proximal
    # variant is constructed with an argument (gamma).
    scale_transform_name = config_dict.get("scale_transform", "Proximal")
    scale_transform = getattr(
        scale_transforms, scale_transform_name + "ScaleTransform")
    if scale_transform_name == "Proximal":
        args['scale_transform'] = scale_transform(
            config_dict.get("scale_transform_gamma", 1.0))
    else:
        args['scale_transform'] = scale_transform()

    ################################################
    # Branch dimensions
    ################################################
    args["N_chunk"] = config_dict['N_leaves']
    model = config_dict['model']
    # NOTE(review): this mapping is required for every var_dist, including
    # non-branch ones such as "LowRankGauss" — confirm that is intended.
    if model in ("BranchGaussian",
                 "BranchGaussianWithData",
                 "BranchConditionalGaussian",
                 "BranchConditionalBernoulli"):
        D_data = config_dict['D_data']
        args["D_par"] = D_data
        args["D_kid"] = D_data
    elif model == "RS":
        D_data = config_dict['D_data']
        # RS: D_par = D + D(D+1)/2, D_kid = D.
        args["D_par"] = D_data + (D_data + 1) * D_data // 2
        args["D_kid"] = D_data
    else:
        raise NotImplementedError(
            f"No relationship between D_par and D_kid is defined for model "
            f"{model!r} in the Branch variational distribution; add one.")

    ################################################
    # Encoder arguments
    ################################################
    if "Amortize" in config_dict['var_dist']:
        encoder = config_dict['encoder']
        args.update({
            "encoder": encoder,
            "model": model,
            "masked_model": config_dict.get('use_mask', False),
        })
        if encoder in ('Encoder', "MaskedEncoder"):
            D_data = config_dict['D_data']
            args.update({
                # NOTE(review): meaning of the second get_data argument is
                # not visible here — confirm against dataloader.
                "data": dataloader.get_data(config_dict, False),
                "fcn_features": config_dict['encoder_fcn_features'],
                "encoding_features": config_dict['encoder_E'],
                "bl_features": D_data + (D_data + 1) * D_data // 2,
                "keep_sum_stats": config_dict.get('encoder_keep_sum_stats', False),
                "split": D_data,
            })

        elif encoder in ('MaskedEncoder_v2',
                         'MaskedEncoder_vGaussian',
                         'MaskedEncoder_vBlockGaussian',
                         'MaskedEncoder_vDiagonal'):
            D_data = config_dict['D_data']
            args.update({
                "data": dataloader.get_data(config_dict, False),
                "fcn_1_features": config_dict['encoder_fcn_1_features'],
                "fcn_2_features": config_dict['encoder_fcn_2_features'],
                "encoding_features": config_dict['encoder_E'],
                "keep_sum_stats": config_dict.get('encoder_keep_sum_stats', False),
                "D_z": D_data,
                # Matches the branch parameter dimension computed above.
                "D_θ": args['D_par'],
                "encode_θ": config_dict.get('encoder_encode_θ', False),
            })

        elif encoder in ('Bilinear', 'BilinearMean'):
            D_data = config_dict['D_data']
            # The bilinear encoders receive an identity matrix over the leaves.
            args["data"] = {'x': np.eye(config_dict['N_leaves'])}

            if encoder == "Bilinear":
                args.update({
                    "bl_features": D_data + (D_data + 1) * D_data // 2,
                    "split": D_data,
                })
            else:  # "BilinearMean"
                args.update({
                    "bl_features": D_data,
                    "features": (D_data + 1) * D_data // 2,
                })

    if config_dict['var_dist'] == "LowRankGauss":
        args['r'] = config_dict['r']

    return args


def get_var_dist(config_dict):
    """Instantiate the variational distribution named by ``config_dict['var_dist']``.

    The class is resolved through ``utils.get_attribute`` over the
    ``var_dists`` and ``var_dists_branched`` modules. It is then constructed
    with the keyword arguments produced by ``get_var_dist_args``.
    """
    var_dist_cls = utils.get_attribute(
        [var_dists, var_dists_branched], config_dict['var_dist'])
    constructor_kwargs = get_var_dist_args(config_dict)
    return var_dist_cls(**constructor_kwargs)

